#include <bits/stdc++.h>
using namespace std;
#define int long long

void solve()
{
    int n,m;
    cin>>n>>m;
    vector<int>v(n+1),w(m+1),f(n+2);
    for(int i=1;i<=n;i++)cin>>v[i];
    for(int i=1;i<=m;i++)cin>>w[i];
    sort(v.begin()+1,v.end());
    sort(w.begin()+1,w.end());
    for(int i=n;i>=1;i--)f[i]=f[i+1]+v[i];
    int ans=0;
    for(int i=1;i<=m;i++)
    {
        int l=0,r=n,mid;
        while(l<r)
        {
            mid=(r-l)/2+l+1;
            if(v[mid]<=w[i])l=mid;
            else r=mid-1;
        }
        ans+=w[i]*l+f[l+1];
    }
    cout<<ans;
}
// Program entry point. Declared `signed main` rather than `int main`
// because `int` is macro-redefined to `long long` in this file, and
// `main` must return a plain int (`signed` alone means `signed int`).
signed main()
{
    // The program uses only cin/cout (no C stdio), so unsyncing and untying
    // is safe and avoids slow synchronized reads on large inputs.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T=1;
    // Single test case; uncomment to read a test-case count for multi-test input.
    //cin>>T;
    while(T--)
    {
       solve();
    }
    return 0;
}